{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "原文代码作者：https://github.com/wzyonggege/statistical-learning-method\n",
    "\n",
    "中文注释制作：机器学习初学者(微信公众号：ID:ai-start-com)\n",
    "\n",
    "配置环境：python 3.6\n",
    "\n",
    "代码全部测试通过。\n",
    "![gongzhong](../gongzhong.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第5章 决策树"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- ID3（基于信息增益）\n",
    "- C4.5（基于信息增益比）\n",
    "- CART（gini指数）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### entropy：$H(x) = -\\sum_{i=1}^{n}p_i\\log{p_i}$\n",
    "\n",
    "#### conditional entropy: $H(X|Y)=\\sum{P(X|Y)}\\log{P(X|Y)}$\n",
    "\n",
    "#### information gain : $g(D, A)=H(D)-H(D|A)$\n",
    "\n",
    "#### information gain ratio: $g_R(D, A) = \\frac{g(D,A)}{H(A)}$\n",
    "\n",
    "#### gini index:$Gini(D)=\\sum_{k=1}^{K}p_k\\log{p_k}=1-\\sum_{k=1}^{K}p_k^2$"
   ]
  },
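  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the entropy and Gini index formulas above, assuming the class distribution is given as a plain list of counts. The helper names `entropy` and `gini` are illustrative only and are not used by the rest of this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import log2\n",
    "\n",
    "def entropy(counts):\n",
    "    # Empirical entropy in bits: -sum(p_k * log2(p_k)), skipping empty classes\n",
    "    total = sum(counts)\n",
    "    return -sum((c / total) * log2(c / total) for c in counts if c > 0)\n",
    "\n",
    "def gini(counts):\n",
    "    # Gini index: 1 - sum(p_k ** 2)\n",
    "    total = sum(counts)\n",
    "    return 1.0 - sum((c / total) ** 2 for c in counts)\n",
    "\n",
    "# Hypothetical example: 9 samples of one class and 6 of another\n",
    "print(round(entropy([9, 6]), 3), round(gini([9, 6]), 3))"
   ]
  },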
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from collections import Counter\n",
    "import math\n",
    "from math import log\n",
    "\n",
    "import pprint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 书上题目5.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 书上题目5.1\n",
    "def create_data():\n",
    "    datasets = [['青年', '否', '否', '一般', '否'],\n",
    "               ['青年', '否', '否', '好', '否'],\n",
    "               ['青年', '是', '否', '好', '是'],\n",
    "               ['青年', '是', '是', '一般', '是'],\n",
    "               ['青年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '一般', '否'],\n",
    "               ['中年', '否', '否', '好', '否'],\n",
    "               ['中年', '是', '是', '好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['中年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '非常好', '是'],\n",
    "               ['老年', '否', '是', '好', '是'],\n",
    "               ['老年', '是', '否', '好', '是'],\n",
    "               ['老年', '是', '否', '非常好', '是'],\n",
    "               ['老年', '否', '否', '一般', '否'],\n",
    "               ]\n",
    "    labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']\n",
    "    # 返回数据集和每个维度的名称\n",
    "    return datasets, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets, labels = create_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = pd.DataFrame(datasets, columns=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>年龄</th>\n",
       "      <th>有工作</th>\n",
       "      <th>有自己的房子</th>\n",
       "      <th>信贷情况</th>\n",
       "      <th>类别</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>青年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>青年</td>\n",
       "      <td>是</td>\n",
       "      <td>是</td>\n",
       "      <td>一般</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>青年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>中年</td>\n",
       "      <td>是</td>\n",
       "      <td>是</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>中年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>是</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>老年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>老年</td>\n",
       "      <td>是</td>\n",
       "      <td>否</td>\n",
       "      <td>非常好</td>\n",
       "      <td>是</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>老年</td>\n",
       "      <td>否</td>\n",
       "      <td>否</td>\n",
       "      <td>一般</td>\n",
       "      <td>否</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    年龄 有工作 有自己的房子 信贷情况 类别\n",
       "0   青年   否      否   一般  否\n",
       "1   青年   否      否    好  否\n",
       "2   青年   是      否    好  是\n",
       "3   青年   是      是   一般  是\n",
       "4   青年   否      否   一般  否\n",
       "5   中年   否      否   一般  否\n",
       "6   中年   否      否    好  否\n",
       "7   中年   是      是    好  是\n",
       "8   中年   否      是  非常好  是\n",
       "9   中年   否      是  非常好  是\n",
       "10  老年   否      是  非常好  是\n",
       "11  老年   否      是    好  是\n",
       "12  老年   是      否    好  是\n",
       "13  老年   是      否  非常好  是\n",
       "14  老年   否      否   一般  否"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 熵\n",
    "def calc_ent(datasets):\n",
    "    data_length = len(datasets)\n",
    "    label_count = {}\n",
    "    for i in range(data_length):\n",
    "        label = datasets[i][-1]\n",
    "        if label not in label_count:\n",
    "            label_count[label] = 0\n",
    "        label_count[label] += 1\n",
    "    ent = -sum([(p/data_length)*log(p/data_length, 2) for p in label_count.values()])\n",
    "    return ent\n",
    "\n",
    "# 经验条件熵\n",
    "def cond_ent(datasets, axis=0):\n",
    "    data_length = len(datasets)\n",
    "    feature_sets = {}\n",
    "    for i in range(data_length):\n",
    "        feature = datasets[i][axis]\n",
    "        if feature not in feature_sets:\n",
    "            feature_sets[feature] = []\n",
    "        feature_sets[feature].append(datasets[i])\n",
    "    cond_ent = sum([(len(p)/data_length)*calc_ent(p) for p in feature_sets.values()])\n",
    "    return cond_ent\n",
    "\n",
    "# 信息增益\n",
    "def info_gain(ent, cond_ent):\n",
    "    return ent - cond_ent\n",
    "\n",
    "def info_gain_train(datasets):\n",
    "    count = len(datasets[0]) - 1\n",
    "    ent = calc_ent(datasets)\n",
    "    best_feature = []\n",
    "    for c in range(count):\n",
    "        c_info_gain = info_gain(ent, cond_ent(datasets, axis=c))\n",
    "        best_feature.append((c, c_info_gain))\n",
    "        print('特征({}) - info_gain - {:.3f}'.format(labels[c], c_info_gain))\n",
    "    # 比较大小\n",
    "    best_ = max(best_feature, key=lambda x: x[-1])\n",
    "    return '特征({})的信息增益最大，选择为根节点特征'.format(labels[best_[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "特征(年龄) - info_gain - 0.083\n",
      "特征(有工作) - info_gain - 0.324\n",
      "特征(有自己的房子) - info_gain - 0.420\n",
      "特征(信贷情况) - info_gain - 0.363\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'特征(有自己的房子)的信息增益最大，选择为根节点特征'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info_gain_train(np.array(datasets))"
   ]
  },
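  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "C4.5 ranks features by the information gain ratio instead of the raw information gain. The cell below is a self-contained sketch under that definition, dividing each gain by the entropy of the feature's own value distribution (the denominator $H(A)$ in the formula list). It repeats the loan data so it can run on its own; the names `rows`, `entropy` and `gain_ratio` are illustrative and are not part of the original code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from math import log2\n",
    "\n",
    "# Loan data from Example 5.1: age, has a job, owns a house, credit rating, class\n",
    "rows = [r.split() for r in [\n",
    "    '青年 否 否 一般 否', '青年 否 否 好 否', '青年 是 否 好 是',\n",
    "    '青年 是 是 一般 是', '青年 否 否 一般 否', '中年 否 否 一般 否',\n",
    "    '中年 否 否 好 否', '中年 是 是 好 是', '中年 否 是 非常好 是',\n",
    "    '中年 否 是 非常好 是', '老年 否 是 非常好 是', '老年 否 是 好 是',\n",
    "    '老年 是 否 好 是', '老年 是 否 非常好 是', '老年 否 否 一般 否',\n",
    "]]\n",
    "\n",
    "def entropy(values):\n",
    "    # Entropy (bits) of the empirical distribution of a list of values\n",
    "    n = len(values)\n",
    "    return -sum((c / n) * log2(c / n) for c in Counter(values).values())\n",
    "\n",
    "def gain_ratio(rows, axis):\n",
    "    # Information gain of column `axis` divided by the entropy of that column\n",
    "    class_col = [r[-1] for r in rows]\n",
    "    feature_col = [r[axis] for r in rows]\n",
    "    cond = 0.0\n",
    "    for v in set(feature_col):\n",
    "        subset = [r[-1] for r in rows if r[axis] == v]\n",
    "        cond += len(subset) / len(rows) * entropy(subset)\n",
    "    gain = entropy(class_col) - cond\n",
    "    split_info = entropy(feature_col)\n",
    "    # A feature with a single value carries no split information\n",
    "    return gain / split_info if split_info > 0 else 0.0\n",
    "\n",
    "ratios = [gain_ratio(rows, a) for a in range(4)]\n",
    "best = max(range(4), key=lambda a: ratios[a])\n",
    "print(best, [round(r, 3) for r in ratios])"
   ]
  },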
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "---\n",
    "\n",
    "利用ID3算法生成决策树，例5.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义节点类 二叉树\n",
    "class Node:\n",
    "    def __init__(self, root=True, label=None, feature_name=None, feature=None):\n",
    "        self.root = root\n",
    "        self.label = label\n",
    "        self.feature_name = feature_name\n",
    "        self.feature = feature\n",
    "        self.tree = {}\n",
    "        self.result = {'label:': self.label, 'feature': self.feature, 'tree': self.tree}\n",
    "\n",
    "    def __repr__(self):\n",
    "        return '{}'.format(self.result)\n",
    "\n",
    "    def add_node(self, val, node):\n",
    "        self.tree[val] = node\n",
    "\n",
    "    def predict(self, features):\n",
    "        if self.root is True:\n",
    "            return self.label\n",
    "        return self.tree[features[self.feature]].predict(features)\n",
    "    \n",
    "class DTree:\n",
    "    def __init__(self, epsilon=0.1):\n",
    "        self.epsilon = epsilon\n",
    "        self._tree = {}\n",
    "\n",
    "    # 熵\n",
    "    @staticmethod\n",
    "    def calc_ent(datasets):\n",
    "        data_length = len(datasets)\n",
    "        label_count = {}\n",
    "        for i in range(data_length):\n",
    "            label = datasets[i][-1]\n",
    "            if label not in label_count:\n",
    "                label_count[label] = 0\n",
    "            label_count[label] += 1\n",
    "        ent = -sum([(p/data_length)*log(p/data_length, 2) for p in label_count.values()])\n",
    "        return ent\n",
    "\n",
    "    # 经验条件熵\n",
    "    def cond_ent(self, datasets, axis=0):\n",
    "        data_length = len(datasets)\n",
    "        feature_sets = {}\n",
    "        for i in range(data_length):\n",
    "            feature = datasets[i][axis]\n",
    "            if feature not in feature_sets:\n",
    "                feature_sets[feature] = []\n",
    "            feature_sets[feature].append(datasets[i])\n",
    "        cond_ent = sum([(len(p)/data_length)*self.calc_ent(p) for p in feature_sets.values()])\n",
    "        return cond_ent\n",
    "\n",
    "    # 信息增益\n",
    "    @staticmethod\n",
    "    def info_gain(ent, cond_ent):\n",
    "        return ent - cond_ent\n",
    "\n",
    "    def info_gain_train(self, datasets):\n",
    "        count = len(datasets[0]) - 1\n",
    "        ent = self.calc_ent(datasets)\n",
    "        best_feature = []\n",
    "        for c in range(count):\n",
    "            c_info_gain = self.info_gain(ent, self.cond_ent(datasets, axis=c))\n",
    "            best_feature.append((c, c_info_gain))\n",
    "        # 比较大小\n",
    "        best_ = max(best_feature, key=lambda x: x[-1])\n",
    "        return best_\n",
    "\n",
    "    def train(self, train_data):\n",
    "        \"\"\"\n",
    "        input:数据集D(DataFrame格式)，特征集A，阈值eta\n",
    "        output:决策树T\n",
    "        \"\"\"\n",
    "        _, y_train, features = train_data.iloc[:, :-1], train_data.iloc[:, -1], train_data.columns[:-1]\n",
    "        # 1,若D中实例属于同一类Ck，则T为单节点树，并将类Ck作为结点的类标记，返回T\n",
    "        if len(y_train.value_counts()) == 1:\n",
    "            return Node(root=True,\n",
    "                        label=y_train.iloc[0])\n",
    "\n",
    "        # 2, 若A为空，则T为单节点树，将D中实例树最大的类Ck作为该节点的类标记，返回T\n",
    "        if len(features) == 0:\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "\n",
    "        # 3,计算最大信息增益 同5.1,Ag为信息增益最大的特征\n",
    "        max_feature, max_info_gain = self.info_gain_train(np.array(train_data))\n",
    "        max_feature_name = features[max_feature]\n",
    "\n",
    "        # 4,Ag的信息增益小于阈值eta,则置T为单节点树，并将D中是实例数最大的类Ck作为该节点的类标记，返回T\n",
    "        if max_info_gain < self.epsilon:\n",
    "            return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n",
    "\n",
    "        # 5,构建Ag子集\n",
    "        node_tree = Node(root=False, feature_name=max_feature_name, feature=max_feature)\n",
    "\n",
    "        feature_list = train_data[max_feature_name].value_counts().index\n",
    "        for f in feature_list:\n",
    "            sub_train_df = train_data.loc[train_data[max_feature_name] == f].drop([max_feature_name], axis=1)\n",
    "\n",
    "            # 6, 递归生成树\n",
    "            sub_tree = self.train(sub_train_df)\n",
    "            node_tree.add_node(f, sub_tree)\n",
    "\n",
    "        # pprint.pprint(node_tree.tree)\n",
    "        return node_tree\n",
    "\n",
    "    def fit(self, train_data):\n",
    "        self._tree = self.train(train_data)\n",
    "        return self._tree\n",
    "\n",
    "    def predict(self, X_test):\n",
    "        return self._tree.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets, labels = create_data()\n",
    "data_df = pd.DataFrame(datasets, columns=labels)\n",
    "dt = DTree()\n",
    "tree = dt.fit(data_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'label:': None, 'feature': 2, 'tree': {'否': {'label:': None, 'feature': 1, 'tree': {'否': {'label:': '否', 'feature': None, 'tree': {}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'否'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dt.predict(['老年', '否', '否', '一般'])"
   ]
  },
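  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `DTree` class above addresses features by column position. As an alternative, the cell below is a compact ID3 sketch that addresses features by name and stores the tree as nested dictionaries, so a sample can be passed as a dict. It is an illustration under these assumptions, not the original author's implementation: the English column names, the `'y'` class key, and the functions `id3` and `predict` are all hypothetical. Note that `predict` raises `KeyError` for a feature value that was never seen during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from math import log2\n",
    "\n",
    "def entropy(labels):\n",
    "    n = len(labels)\n",
    "    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())\n",
    "\n",
    "def id3(rows, names, epsilon=0.1):\n",
    "    # rows: list of dicts mapping column name -> value; the class is stored under 'y'\n",
    "    labels = [r['y'] for r in rows]\n",
    "    majority = Counter(labels).most_common(1)[0][0]\n",
    "    # Leaf cases: a pure node, or no features left to split on\n",
    "    if len(set(labels)) == 1 or not names:\n",
    "        return majority\n",
    "\n",
    "    def gain(name):\n",
    "        cond = 0.0\n",
    "        for v in set(r[name] for r in rows):\n",
    "            sub = [r['y'] for r in rows if r[name] == v]\n",
    "            cond += len(sub) / len(rows) * entropy(sub)\n",
    "        return entropy(labels) - cond\n",
    "\n",
    "    best = max(names, key=gain)\n",
    "    # Stop splitting when the best gain is below the threshold\n",
    "    if gain(best) < epsilon:\n",
    "        return majority\n",
    "    rest = [n for n in names if n != best]\n",
    "    children = {v: id3([r for r in rows if r[best] == v], rest, epsilon)\n",
    "                for v in set(r[best] for r in rows)}\n",
    "    return {'feature': best, 'children': children}\n",
    "\n",
    "def predict(tree, sample):\n",
    "    # Walk down by feature name until a leaf (a plain class label) is reached\n",
    "    while isinstance(tree, dict):\n",
    "        tree = tree['children'][sample[tree['feature']]]\n",
    "    return tree\n",
    "\n",
    "names = ['age', 'job', 'house', 'credit']\n",
    "data = [\n",
    "    '青年 否 否 一般 否', '青年 否 否 好 否', '青年 是 否 好 是',\n",
    "    '青年 是 是 一般 是', '青年 否 否 一般 否', '中年 否 否 一般 否',\n",
    "    '中年 否 否 好 否', '中年 是 是 好 是', '中年 否 是 非常好 是',\n",
    "    '中年 否 是 非常好 是', '老年 否 是 非常好 是', '老年 否 是 好 是',\n",
    "    '老年 是 否 好 是', '老年 是 否 非常好 是', '老年 否 否 一般 否',\n",
    "]\n",
    "rows = [dict(zip(names + ['y'], s.split())) for s in data]\n",
    "tree = id3(rows, names)\n",
    "sample = {'age': '老年', 'job': '否', 'house': '否', 'credit': '一般'}\n",
    "print(tree['feature'], predict(tree, sample))"
   ]
  },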
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## sklearn.tree.DecisionTreeClassifier\n",
    "\n",
    "### criterion : string, optional (default=”gini”)\n",
    "The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain."
   ]
  },
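  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `gini` criterion concrete, the cell below sketches the weighted Gini index of a binary split on a categorical feature, using the house-ownership and class columns of the loan data from Example 5.1. The helpers `gini` and `gini_split` are illustrative names, not scikit-learn functions; scikit-learn applies the same impurity measure to threshold splits on numeric features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def gini(labels):\n",
    "    # Gini index 1 - sum(p_k ** 2) of a list of class labels\n",
    "    n = len(labels)\n",
    "    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())\n",
    "\n",
    "def gini_split(feature, labels, value):\n",
    "    # Weighted Gini index after the binary split feature == value vs. feature != value\n",
    "    left = [y for x, y in zip(feature, labels) if x == value]\n",
    "    right = [y for x, y in zip(feature, labels) if x != value]\n",
    "    n = len(labels)\n",
    "    return sum(len(part) / n * gini(part) for part in (left, right) if part)\n",
    "\n",
    "# House-ownership and class columns of the loan data, one character per row\n",
    "house = list('否否否是否否否是是是是是否否否')\n",
    "label = list('否否是是否否否是是是是是是是否')\n",
    "print(round(gini(label), 3), round(gini_split(house, label, '是'), 3))"
   ]
  },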
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data\n",
    "def create_data():\n",
    "    iris = load_iris()\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    # print(data)\n",
    "    return data[:,:2], data[:,-1]\n",
    "\n",
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "from sklearn.tree import export_graphviz\n",
    "import graphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n",
       "            max_features=None, max_leaf_nodes=None,\n",
       "            min_impurity_decrease=0.0, min_impurity_split=None,\n",
       "            min_samples_leaf=1, min_samples_split=2,\n",
       "            min_weight_fraction_leaf=0.0, presort=False, random_state=None,\n",
       "            splitter='best')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = DecisionTreeClassifier()\n",
    "clf.fit(X_train, y_train,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_pic = export_graphviz(clf, out_file=\"mytree.pdf\")\n",
    "with open('mytree.pdf') as f:\n",
    "    dot_graph = f.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"550pt\" height=\"477pt\"\r\n",
       " viewBox=\"0.00 0.00 550.00 477.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 473)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-473 546,-473 546,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"377.5,-469 273.5,-469 273.5,-401 377.5,-401 377.5,-469\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"325.5\" y=\"-453.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 5.45</text>\r\n",
       "<text text-anchor=\"middle\" x=\"325.5\" y=\"-438.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"325.5\" y=\"-423.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 70</text>\r\n",
       "<text text-anchor=\"middle\" x=\"325.5\" y=\"-408.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [35, 35]</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"316.5,-365 218.5,-365 218.5,-297 316.5,-297 316.5,-365\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 2.85</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.234</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 37</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [32, 5]</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M306.669,-400.884C301.807,-392.332 296.508,-383.013 291.423,-374.072\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"294.421,-372.262 286.435,-365.299 288.335,-375.722 294.421,-372.262\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"279.806\" y=\"-385.704\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 10 -->\r\n",
       "<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"432.5,-365 334.5,-365 334.5,-297 432.5,-297 432.5,-365\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"383.5\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 3.45</text>\r\n",
       "<text text-anchor=\"middle\" x=\"383.5\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.165</text>\r\n",
       "<text text-anchor=\"middle\" x=\"383.5\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 33</text>\r\n",
       "<text text-anchor=\"middle\" x=\"383.5\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [3, 30]</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;10 -->\r\n",
       "<g id=\"edge10\" class=\"edge\"><title>0&#45;&gt;10</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M344.331,-400.884C349.193,-392.332 354.492,-383.013 359.577,-374.072\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"362.665,-375.722 364.565,-365.299 356.579,-372.262 362.665,-375.722\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"371.194\" y=\"-385.704\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"200,-261 109,-261 109,-193 200,-193 200,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 4.7</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.32</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 4]</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M230.812,-296.884C220.648,-287.709 209.505,-277.65 198.95,-268.123\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"201.159,-265.402 191.39,-261.299 196.468,-270.598 201.159,-265.402\"/>\r\n",
       "</g>\r\n",
       "<!-- 5 -->\r\n",
       "<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"316.5,-261 218.5,-261 218.5,-193 316.5,-193 316.5,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 5.35</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.061</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 32</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [31, 1]</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;5 -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>1&#45;&gt;5</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M267.5,-296.884C267.5,-288.778 267.5,-279.982 267.5,-271.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"271,-271.299 267.5,-261.299 264,-271.299 271,-271.299\"/>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"91,-149.5 0,-149.5 0,-96.5 91,-96.5 91,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M119.111,-192.884C106.653,-181.226 92.6699,-168.141 80.2641,-156.532\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"82.4644,-153.797 72.7713,-149.52 77.6815,-158.908 82.4644,-153.797\"/>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"200,-149.5 109,-149.5 109,-96.5 200,-96.5 200,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 4</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 4]</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>2&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M154.5,-192.884C154.5,-182.326 154.5,-170.597 154.5,-159.854\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"158,-159.52 154.5,-149.52 151,-159.52 158,-159.52\"/>\r\n",
       "</g>\r\n",
       "<!-- 6 -->\r\n",
       "<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"316.5,-149.5 218.5,-149.5 218.5,-96.5 316.5,-96.5 316.5,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 28</text>\r\n",
       "<text text-anchor=\"middle\" x=\"267.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [28, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 5&#45;&gt;6 -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>5&#45;&gt;6</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M267.5,-192.884C267.5,-182.326 267.5,-170.597 267.5,-159.854\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"271,-159.52 267.5,-149.52 264,-159.52 271,-159.52\"/>\r\n",
       "</g>\r\n",
       "<!-- 7 -->\r\n",
       "<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"426,-157 335,-157 335,-89 426,-89 426,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"380.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 3.2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"380.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.375</text>\r\n",
       "<text text-anchor=\"middle\" x=\"380.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 4</text>\r\n",
       "<text text-anchor=\"middle\" x=\"380.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [3, 1]</text>\r\n",
       "</g>\r\n",
       "<!-- 5&#45;&gt;7 -->\r\n",
       "<g id=\"edge7\" class=\"edge\"><title>5&#45;&gt;7</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M304.188,-192.884C314.352,-183.709 325.495,-173.65 336.05,-164.123\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"338.532,-166.598 343.61,-157.299 333.841,-161.402 338.532,-166.598\"/>\r\n",
       "</g>\r\n",
       "<!-- 8 -->\r\n",
       "<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"371,-53 280,-53 280,-0 371,-0 371,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"325.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"325.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"325.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
       "</g>\r\n",
       "<!-- 7&#45;&gt;8 -->\r\n",
       "<g id=\"edge8\" class=\"edge\"><title>7&#45;&gt;8</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M361.264,-88.9485C356.206,-80.2579 350.736,-70.8608 345.633,-62.0917\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"348.534,-60.1189 340.479,-53.2367 342.484,-63.6401 348.534,-60.1189\"/>\r\n",
       "</g>\r\n",
       "<!-- 9 -->\r\n",
       "<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"480,-53 389,-53 389,-0 480,-0 480,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"434.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"434.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"434.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [3, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 7&#45;&gt;9 -->\r\n",
       "<g id=\"edge9\" class=\"edge\"><title>7&#45;&gt;9</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M399.387,-88.9485C404.353,-80.2579 409.722,-70.8608 414.733,-62.0917\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"417.871,-63.6557 419.793,-53.2367 411.793,-60.1826 417.871,-63.6557\"/>\r\n",
       "</g>\r\n",
       "<!-- 11 -->\r\n",
       "<g id=\"node12\" class=\"node\"><title>11</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"432.5,-253.5 334.5,-253.5 334.5,-200.5 432.5,-200.5 432.5,-253.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"383.5\" y=\"-238.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"383.5\" y=\"-223.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 30</text>\r\n",
       "<text text-anchor=\"middle\" x=\"383.5\" y=\"-208.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 30]</text>\r\n",
       "</g>\r\n",
       "<!-- 10&#45;&gt;11 -->\r\n",
       "<g id=\"edge11\" class=\"edge\"><title>10&#45;&gt;11</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M383.5,-296.884C383.5,-286.326 383.5,-274.597 383.5,-263.854\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"387,-263.52 383.5,-253.52 380,-263.52 387,-263.52\"/>\r\n",
       "</g>\r\n",
       "<!-- 12 -->\r\n",
       "<g id=\"node13\" class=\"node\"><title>12</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"542,-253.5 451,-253.5 451,-200.5 542,-200.5 542,-253.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"496.5\" y=\"-238.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"496.5\" y=\"-223.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"496.5\" y=\"-208.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [3, 0]</text>\r\n",
       "</g>\r\n",
       "<!-- 10&#45;&gt;12 -->\r\n",
       "<g id=\"edge12\" class=\"edge\"><title>10&#45;&gt;12</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M420.188,-296.884C433.103,-285.226 447.599,-272.141 460.46,-260.532\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"463.15,-262.819 468.228,-253.52 458.46,-257.622 463.15,-262.819\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x22d67bed7b8>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graphviz.Source(dot_graph)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
